"""
@author : linrh
@homepage : https://gitee.com/linrh-DUT
@version: 1.0.0
@when : 2023/5/18
@file: data.py
"""
from conf import *
from util.data_loader import DataLoader
from util.tokenizer import Tokenizer

# Build the English -> German data pipeline when this module is imported:
# tokenizer, loader, train/valid/test splits, vocabularies, and batch iterators.
# NOTE(review): all of this runs as a module-level side effect on import.
tokenizer = Tokenizer()
# The source side is English ('en', tokenize_en); the target side is German
# ('de', tokenize_de).
loader = DataLoader(pair=('en', 'de'),
                    tokenize_source=tokenizer.tokenize_en,
                    tokenize_target=tokenizer.tokenize_de)

# Load the three dataset splits.
train, valid, test = loader.make_dataset()
# The vocabulary is built only after the splits exist. Presumably it is built
# from them, and min_freq=2 drops tokens seen fewer than twice.
# TODO confirm in DataLoader.build_vocab.
loader.build_vocab(min_freq=2)
# Wrap the splits in batch iterators. `batch_size` and `device` are not defined
# in this file; they presumably come from `from conf import *`.
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=batch_size,
                                                     device=device)

# Special-token indices: source and target padding both use the loader's
# PAD_IDX, and the target start-of-sequence index is the loader's BOS_IDX.
src_pad_idx = trg_pad_idx = loader.PAD_IDX
trg_sos_idx = loader.BOS_IDX

# Vocabulary sizes: the encoder size comes from the source vocab and the
# decoder size from the target vocab.
enc_voc_size, dec_voc_size = (len(loader.vocab_source),
                              len(loader.vocab_target))
